30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/DebugLog.h"
33 #include "llvm/Support/InterleavedRange.h"
34 #include "llvm/Support/raw_ostream.h"
37 #define DEBUG_TYPE "linalg-transforms"
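// peelLoop: try to peel and canonicalize the loop `op`; on success, return
// the results of the peeled partial iteration, otherwise the original results.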
56 .Case<scf::ForOp>([&](scf::ForOp forOp) {
57 scf::ForOp partialIteration;
60 return partialIteration->getResults();
61 assert(!partialIteration && "expected that loop was not peeled");
62 return forOp->getResults();
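// peelLoops: peel each loop in `loops`, applying affine_min/max bounds
// simplification on the fly where relevant.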
71 for (auto loopOp : loops)
84 if (!e.isFunctionOfDim(dim))
95 return llvm::interleaved(ri, ", ", "|", "");
146 static FailureOr<SmallVector<std::optional<int64_t>>>
150 int64_t newDim = iteratorTypes.size();
151 iteratorTypes.push_back(iteratorTypes[dim]);
154 indexingMaps.size(), std::nullopt);
156 for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
158 AffineMap map = indexingMaps[operandIdx];
161 assert(map.getNumDims() == newDim && "num dims invariant violation");
169 "num results invariant violation");
171 if (!maybeOperandDimensionToPack.has_value()) {
172 newMaps.push_back(map);
177 if (!isa<AffineDimExpr>(map.getResult(maybeOperandDimensionToPack.value())))
183 newMaps.push_back(map);
186 packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
188 indexingMaps = newMaps;
190 return packedDimPerIndexingMap;
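// PackedOperandsDim records one packed loop: its packed size and, per operand,
// the operand dimension being packed (if any). PackedOperandsDimList
// aggregates one such entry per packed loop dimension.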
196 struct PackedOperandsDim {
202 struct PackedOperandsDimList {
203 void pushBack(PackedOperandsDim &&packedOperandsDims) {
204 spec.emplace_back(packedOperandsDims);
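// lowerPack: rewrite linalg.pack as tensor.pad + tensor.expand_shape +
// linalg.transpose (or a single tensor.insert_slice for pad-like packs).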
218 linalg::PackOp packOp,
219 bool lowerPadLikeWithInsertSlice) {
221 auto packedTensorType =
222 cast<RankedTensorType>(packOp->getResultTypes().front());
223 if (llvm::any_of(packOp.getStaticInnerTiles(), ShapedType::isDynamic)) {
226 "non-static shape NYI, needs a more powerful tensor.expand_shape op");
235 PackingMetadata packingMetadata = computePackingMetadata(
236 packedTensorType.getRank(), packOp.getInnerDimsPos());
250 for (auto [pos, innerSize] :
251 llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
253 packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
263 rewriter, loc, map, {outerSize, origSize, innerSize});
265 RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
267 packingMetadata.reassociations);
268 Value paddingValue = packOp.getPaddingValue();
270 paddingValue = arith::ConstantOp::create(
274 tensor::PadOp::create(rewriter, loc, collapsed, packOp.getSource(), lows,
275 highs, paddingValue, /*nofold=*/false);
277 LDBG() << "insertPositions: "
278 << llvm::interleaved(packingMetadata.insertPositions);
279 LDBG() << "outerPositions: "
280 << llvm::interleaved(packingMetadata.outerPositions);
281 LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
282 LDBG() << "packedToStripMinedShapePerm: "
283 << llvm::interleaved(packedToStripMinedShapePerm);
284 LDBG() << "reassociations: "
285 << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
286 stringifyReassocIndices));
287 LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
288 LDBG() << "collapsed type: " << collapsed;
290 if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
309 auto insertSliceOp = tensor::InsertSliceOp::create(
310 rewriter, loc, padOp, packOp.getDest(),
313 LDBG() << "insert_slice op: " << insertSliceOp;
315 rewriter.replaceOp(packOp, insertSliceOp->getResults());
323 auto expandShapeResultType =
325 auto reshapeOp = tensor::ExpandShapeOp::create(
326 rewriter, loc, expandShapeResultType, padOp.getResult(),
327 packingMetadata.reassociations);
332 auto transposeOp = linalg::TransposeOp::create(
333 rewriter, loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
335 LDBG() << "reshape op: " << reshapeOp;
336 LDBG() << "transpPerm: " << llvm::interleaved(transpPerm);
337 LDBG() << "transpose op: " << transposeOp;
340 rewriter.replaceOp(packOp, transposeOp->getResults());
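// A minimal usage sketch for lowerPack (the pattern name here is assumed;
// only linalg::lowerPack itself is the public entry point):
//
//   struct LowerPackAsPadTransposePattern
//       : OpRewritePattern<linalg::PackOp> {
//     using OpRewritePattern::OpRewritePattern;
//     LogicalResult matchAndRewrite(linalg::PackOp packOp,
//                                   PatternRewriter &rewriter) const override {
//       FailureOr<LowerPackResult> res = linalg::lowerPack(rewriter, packOp);
//       return success(succeeded(res));
//     }
//   };
//
// lowerUnPack (below) is the inverse lowering: rewrite linalg.unpack as
// empty + transpose + reshape + extract_slice.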
345 FailureOr<LowerUnPackOpResult>
347 bool lowerUnpadLikeWithExtractSlice) {
352 RankedTensorType packedTensorType = unPackOp.getSourceType();
353 int64_t packedRank = packedTensorType.getRank();
356 auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
357 if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
366 auto extractSliceOp = tensor::ExtractSliceOp::create(
367 rewriter, loc, destTensorType, unPackOp.getSource(),
371 rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
374 nullptr, extractSliceOp};
379 PackingMetadata packingMetadata;
389 RankedTensorType stripMinedTensorType =
391 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
392 stripMinedTensorType, packingMetadata.reassociations);
399 auto emptyOp = tensor::EmptyOp::create(rewriter, loc, dims,
400 stripMinedTensorType.getElementType());
402 linalg::TransposeOp::create(rewriter, loc, unPackOp.getSource(), emptyOp,
403 packedToStripMinedShapePerm);
405 LDBG() << "insertPositions: "
406 << llvm::interleaved(packingMetadata.insertPositions);
407 LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
408 LDBG() << "packedToStripMinedShapePerm: "
409 << llvm::interleaved(packedToStripMinedShapePerm);
410 LDBG() << "reassociations: "
411 << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
412 stringifyReassocIndices));
413 LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
414 LDBG() << "collapsed type: " << collapsedType;
417 auto reshapeOp = tensor::CollapseShapeOp::create(
418 rewriter, loc, collapsedType, transposeOp->getResult(0),
419 packingMetadata.reassociations);
422 int64_t destRank = destTensorType.getRank();
423 auto extractSliceOp = tensor::ExtractSliceOp::create(
424 rewriter, loc, destTensorType, reshapeOp->getResult(0),
430 auto copyOp = linalg::CopyOp::create(
431 rewriter, loc, extractSliceOp->getResult(0), unPackOp.getDest());
434 rewriter.replaceOp(unPackOp, copyOp->getResults());
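// Accessors on PackedOperandsDimList: gather, for one operand, the packed
// operand dimensions and the corresponding pack sizes across all entries.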
440 PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
442 for (auto &i : spec) {
443 if (!i.packedDimForEachOperand[operandPos].has_value())
445 res.push_back(i.packedDimForEachOperand[operandPos].value());
451 PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
453 for (auto &i : spec) {
454 if (!i.packedDimForEachOperand[operandPos].has_value())
456 res.push_back(i.packedSize);
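// pack: implement packing of a single LinalgOp by `packedSizes`, which must
// contain one size per loop (a constant zero means "do not pack this loop").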
465 linalg::LinalgOp linalgOp,
467 if (packedSizes.size() != linalgOp.getNumLoops()) {
469 "incorrect number of pack sizes");
475 linalgOp.getIteratorTypesArray();
476 LDBG() << "Start packing: " << linalgOp;
477 LDBG() << "maps: " << llvm::interleaved(indexingMaps);
478 LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);
483 PackedOperandsDimList listOfPackedOperandsDim;
484 for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
487 if (maybeConstant.has_value() && maybeConstant.value() == 0)
490 PackedOperandsDim packedOperandsDims;
491 packedOperandsDims.packedSize = packedSizes[i];
492 FailureOr<SmallVector<std::optional<int64_t>>>
493 maybePackedDimForEachOperand =
495 if (failed(maybePackedDimForEachOperand))
497 packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
498 listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
500 LDBG() << "++++ After pack size #" << i << ": " << packedSizes[i];
501 LDBG() << "maps: " << llvm::interleaved(indexingMaps);
502 LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);
503 LDBG() << "packedDimForEachOperand: "
504 << llvm::interleaved(packedOperandsDims.packedDimForEachOperand);
510 llvm::to_vector(llvm::make_pointer_range(linalgOp.getDpsInitsMutable()));
512 for (const auto &operandsList : {inputOperands, initOperands}) {
513 for (OpOperand *opOperand : operandsList) {
514 int64_t pos = opOperand->getOperandNumber();
515 Value operand = opOperand->get();
517 listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
519 listOfPackedOperandsDim.extractPackSizesForOperand(pos);
520 LDBG() << "operand: " << operand;
521 LDBG() << "innerPos: " << llvm::interleaved(innerPos);
522 LDBG() << "innerPackSizes: " << llvm::interleaved(innerPackSizes);
523 if (innerPackSizes.empty()) {
524 inputsAndInits.push_back(operand);
527 Value dest = linalg::PackOp::createDestinationTensor(
528 rewriter, loc, operand, innerPackSizes, innerPos,
530 ShapedType operandType = cast<ShapedType>(operand.getType());
531 bool areConstantTiles =
535 if (areConstantTiles && operandType.hasStaticShape() &&
536 !linalg::PackOp::requirePaddingValue(
537 operandType.getShape(), innerPos,
538 cast<ShapedType>(dest.getType()).getShape(), {},
540 packOps.push_back(linalg::PackOp::create(rewriter, loc, operand, dest,
541 innerPos, innerPackSizes));
547 Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
548 packOps.push_back(linalg::PackOp::create(
549 rewriter, loc, operand, dest, innerPos, innerPackSizes, zero));
551 inputsAndInits.push_back(packOps.back());
557 ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
559 ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
560 auto packedLinalgOp =
561 linalg::GenericOp::create(rewriter, linalgOp.getLoc(), inits.getTypes(),
562 inputs, inits, indexingMaps, iteratorTypes);
563 packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));
566 for (OpResult result : packedLinalgOp->getResults()) {
567 int64_t resultNum = result.getResultNumber();
568 linalg::PackOp maybePackedInit =
569 inits[resultNum].getDefiningOp<linalg::PackOp>();
570 if (!maybePackedInit) {
571 results.push_back(result);
575 unPackOps.push_back(linalg::UnPackOp::create(
576 rewriter, packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
577 maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
578 results.push_back(unPackOps.back());
586 cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
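// A minimal usage sketch for pack() (values assumed): pack the three loops of
// a matmul by 8x16x32; packedSizes must have one entry per loop.
//
//   SmallVector<OpFoldResult> packedSizes = {rewriter.getIndexAttr(8),
//                                            rewriter.getIndexAttr(16),
//                                            rewriter.getIndexAttr(32)};
//   FailureOr<PackResult> packed =
//       linalg::pack(rewriter, matmulOp, packedSizes);
//
// The next fragment is a static helper (transposeOneLinalgOperandAndReplace
// upstream) that replaces one operand of `linalgOp` by `transposedValue` and
// composes the matching indexing map with the given permutation.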
615 assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");
619 cast<RankedTensorType>(opOperand.get().getType()), permutation);
621 assert(tensorType == transposedValue.getType() &&
622 "unexpected tensor type mismatch");
627 llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
631 permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));
635 indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
641 auto transposedGenericOp = linalg::GenericOp::create(
645 operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
646 operandsRef.take_front(linalgOp.getNumDpsInputs()),
647 operandsRef.drop_front(linalgOp.getNumDpsInputs()),
649 linalgOp.getIteratorTypesArray());
650 transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
651 rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());
653 return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
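// packTranspose: transpose a single PackOp -> LinalgOp -> UnPackOp chain with
// the given outer/inner permutations and return the transposed chain.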
656 FailureOr<PackTransposeResult>
658 linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp,
665 linalg::PackOp transposedPackOp =
666 packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);
668 if (!packOp.getResult().hasOneUse())
671 OpOperand &packUse = *packOp->getUses().begin();
672 if (packUse.getOwner() != linalgOp) {
674 linalgOp, "not a single use by the LinalgOp target");
677 (!linalgOp.isDpsInit(&packUse) ||
678 maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
680 "not produced by the LinalgOp target");
686 int64_t numLeadingDims = packOp.getSourceRank();
687 int64_t numTrailingDims = packOp.getInnerDimsPos().size();
691 if (permutation.empty())
692 llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
694 if (innerPerm.empty()) {
697 llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
699 llvm::append_range(permutation,
700 llvm::map_range(innerPerm, [&](int64_t pos) {
701 return numLeadingDims + pos;
713 rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
716 linalg::UnPackOp transposedUnPackOp;
719 transposedLinalgOp->getOpOperand(packUseOperandNumber);
720 OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
722 transposedUnPackOp = maybeUnPackOp.createTransposedClone(
723 rewriter, loc, transposedResult, innerPerm, outerPerm);
725 rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
729 rewriter.replaceOp(packOp, transposedPackOp->getResults());
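// packMatmulGreedily: infer matmul dimensions (m, n, k) in `linalgOp` and
// pack them greedily by mnkPackedSizes, optionally padding each dimension to
// the next multiple given in mnkPaddedSizesNextMultipleOf.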
747 FailureOr<PackResult>
752 assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
753 assert((mnkPaddedSizesNextMultipleOf.empty() ||
754 mnkPaddedSizesNextMultipleOf.size() == 3) &&
755 "num of packing sizes next multiple should be empty or of size 3");
756 assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
759 int64_t numLoops = linalgOp.getNumLoops();
761 LDBG() << "need 3+ loops to find a matmul to pack, got " << numLoops
762 << " in: " << linalgOp;
764 linalgOp, "need 3+ loops to find a matmul to pack");
768 int64_t numPackedDims = mnkPackedSizes.size();
770 for (int64_t i = 0, e = numPackedDims; i < e; ++i)
771 mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
773 for (int64_t i = 0, e = numPackedDims; i < e; ++i)
774 packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
776 for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
777 paddedSizesNextMultipleOf[mnkOrder[i]] =
778 mnkPaddedSizesNextMultipleOf.empty() ? 0
779 : mnkPaddedSizesNextMultipleOf[i];
783 FailureOr<ContractionDimensions> maybeDimensions =
785 if (failed(maybeDimensions)) {
786 LDBG() << "couldn't infer matmul iterators in: " << linalgOp;
788 "couldn't infer matmul iterators");
796 int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
797 kPos = maybeDimensions->k.back();
798 LDBG() << "Start packing generic op greedily with (m@" << mPos << ", n@"
799 << nPos << ", k@" << kPos << "): " << linalgOp;
802 auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
804 FailureOr<GenericOp> generalizeResult =
806 assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
807 genericOp = *generalizeResult;
815 LDBG() << "perm: " << llvm::interleaved(permutation);
818 FailureOr<GenericOp> interchangeResult =
820 assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
821 genericOp = *interchangeResult;
822 LDBG() << "Generalized Op to pack: " << genericOp;
839 cast<LinalgOp>(genericOp.getOperation())
840 .createLoopRanges(rewriter, genericOp.getLoc());
844 LDBG() << "paddedSizesNextMultipleOf: "
845 << llvm::interleaved(paddedSizesNextMultipleOf);
846 LDBG() << "loopRanges: "
847 << llvm::interleaved(
848 llvm::map_range(loopRanges, [](Range r) { return r.size; }));
851 for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
852 if (paddedSizesNextMultipleOf[i] == 0) {
853 adjustedPackedSizes.push_back(packedSizes[i]);
860 rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
861 {loopRanges[adjustedPackedSizes.size()].size,
862 rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
864 LDBG() << "adjustedPackedSizes: " << llvm::interleaved(adjustedPackedSizes);
870 return pack(rewriter, genericOp, adjustedPackedSizes);
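// LinalgTilingOptions::setTileSizes: set tileSizeComputationFunction to
// materialize the given static tile sizes as constants at the start of the
// enclosing function body.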
879 assert(!tileSizeComputationFunction && "tile sizes already set");
884 &op->getParentOfType<func::FuncOp>().getBody().front());
885 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
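// DecomposePadOpPattern: rewrite tensor.pad as tensor.empty + linalg.fill +
// tensor.insert_slice. createFillOrGenerateOp fills the destination with the
// constant padding value when one exists and falls back to tensor.generate.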
903 auto padValue = padOp.getConstantPaddingValue();
906 if (padValue.getParentBlock() == &padOp.getRegion().front())
908 return FillOp::create(rewriter, padOp.getLoc(), padValue, dest).result();
912 auto generateOp = tensor::GenerateOp::create(rewriter, padOp.getLoc(),
913 padOp.getResultType(), dynSizes);
916 padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
925 if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
928 rewriter, padOp.getLoc(),
929 cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
933 auto resultType = padOp.getResultType();
937 for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
938 if (resultType.isDynamicDim(dim)) {
940 padOp.getSource(), dim));
943 padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
945 padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
946 dynSizes.push_back(plusHigh);
948 staticSizes.push_back(resultType.getDimSize(dim));
953 tensor::EmptyOp::create(rewriter, padOp.getLoc(), staticSizes,
954 resultType.getElementType(), dynSizes);
958 auto sourceType = padOp.getSourceType();
966 padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
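// ExtractSliceOfPadTensorSwapPattern (upstream name): swap a unit-stride
// tensor.extract_slice of a tensor.pad by calling bubbleUpPadSlice, which
// takes the slice first and then performs the padding.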
974 if (!sliceOp.hasUnitStride())
977 auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
981 bool zeroSliceGuard = true;
983 if (std::optional<bool> control = controlFn(sliceOp))
984 zeroSliceGuard = *control;
989 FailureOr<TilingResult> tilingResult =
991 sliceOp.getMixedSizes(), zeroSliceGuard);
995 RankedTensorType sourceType = sliceOp.getSourceType();
996 RankedTensorType resultType = sliceOp.getResultType();
1000 if (sourceType.getRank() == resultType.getRank()) {
1001 rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
1007 rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);
1009 rewriter.replaceOp(sliceOp, rankReduced);
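// Helper (getPackOpSourceOrPaddedSource upstream): return the pack's source,
// padded up to the tile sizes with the pack's padding value; only packs whose
// outer dims are all 1 are expected here.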
1019 linalg::PackOp packOp) {
1020 Value input = packOp.getSource();
1021 if (!packOp.getPaddingValue()) {
1025 assert(llvm::all_of(packOp.getAllOuterDims(),
1026 [](int64_t val) { return val == 1; }) &&
1027 "some outer dims are != 1");
1030 ShapedType inputType = packOp.getSourceType();
1031 int64_t inputRank = inputType.getRank();
1034 packOp.getDimAndTileMapping();
1041 for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
1044 if (!tileAndPosMapping.count(dimIdx)) {
1045 int64_t inputDimSize = inputType.getDimSize(dimIdx);
1046 assert(inputDimSize == 1 &&
1047 "with all outer dims == 1, this non-tiled input dim should be 1!");
1048 paddedShape.push_back(inputDimSize);
1055 OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
1059 if (cstTileSize.has_value()) {
1060 paddedShape.push_back(cstTileSize.value());
1065 paddedShape.push_back(ShapedType::kDynamic);
1068 dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
1073 false, loc, builder,
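// Permutation helpers backing getPackInverseDestPerm / getUnPackInverseSrcPerm:
// compute the inner/outer dimension permutation applied by a (un)pack op.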
1083 constexpr int64_t kNonTiledMarker = -1;
1088 vec, [&](int64_t v) { return v != kNonTiledMarker; });
1103 int64_t unpackedRank = shape.size();
1104 for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
1106 innerDims.push_back(dim++);
1111 outerDims.push_back(dim++);
1119 applyPermutationToVector<int64_t>(innerDims, innerPerm);
1124 rankReducedOuterDimsPerm =
1126 if (!rankReducedOuterDimsPerm.empty())
1127 applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);
1130 perm.append(innerDims);
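// DecomposeOuterUnitDimsPackOpPattern: rewrite a linalg.pack whose outer dims
// are all 1 as pad + linalg.transpose + tensor.insert_slice.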
1139 if (llvm::any_of(packOp.getAllOuterDims(),
1140 [](int64_t dim) { return dim != 1; })) {
1142 packOp, "not all outer dimensions of the result are 1s");
1151 packOp.getDimAndTileMapping();
1152 int64_t srcRank = packOp.getSourceRank();
1153 int64_t destRank = packOp.getDestRank();
1154 int64_t numTiles = destRank - srcRank;
1160 for (auto i : llvm::seq<unsigned>(0, srcRank)) {
1161 if (dimAndTileMapping.count(i)) {
1165 auto [_, tileSize] =
1167 tileSizes.push_back(tileSize);
1180 for (int64_t i = 0; i < srcRank; i++) {
1188 if (llvm::is_contained(innerDimPos, i))
1190 srcPermForTranspose.push_back(i);
1192 srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end());
1194 LDBG() << "Pack permutation: " << packOp;
1195 LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
1200 transShapeForEmptyOp.append(tileSizes);
1202 applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
1203 srcPermForTranspose);
1205 tensor::EmptyOp::create(rewriter, loc, transShapeForEmptyOp,
1206 packOp.getSourceType().getElementType());
1209 auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
1210 srcPermForTranspose);
1221 for (auto tileSize : packOp.getMixedTiles()) {
1222 auto [tileSizeStatic, tileSizeOfr] =
1224 writeSizes.push_back(tileSizeOfr);
1225 writeShape.push_back(tileSizeStatic);
1229 auto insert = tensor::InsertSliceOp::create(
1230 rewriter, loc, transposedOp.getResult()[0], packOp.getDest(),
1231 writeOffsets, writeSizes, writeStrides);
1232 rewriter.replaceOp(packOp, insert.getResult());
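// DecomposeOuterUnitDimsUnPackOpPattern: rewrite a linalg.unpack whose tiled
// outer dims are all 1 as rank-reduced extract_slice + linalg.transpose +
// insert_slice ops.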
1239 int64_t srcRank = unpackOp.getSourceRank();
1240 int64_t destRank = unpackOp.getDestRank();
1243 if (llvm::any_of(unpackOp.getTiledOuterDims(),
1244 [](int64_t dim) { return dim != 1; })) {
1247 "require the tiled outer dimensions of the result are all 1s");
1253 Value source = unpackOp.getSource();
1255 unpackOp.getDimAndTileMapping();
1278 for (auto i : llvm::seq<unsigned>(0, destRank)) {
1287 if (dimAndTileMapping.count(i)) {
1288 extractSliceSizes.push_back(oneIdxAttr);
1294 if (ShapedType::isDynamic(srcShape[i])) {
1296 tensor::DimOp::create(rewriter, loc, source, i).getResult();
1297 extractSliceSizes.push_back(dynamicDim);
1298 shapeForEmptyOp.push_back(dynamicDim);
1300 extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
1301 if (srcShape[i] != 1)
1302 shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
1306 if (srcShape[i] != 1) {
1307 readShapeForExtractSlice.push_back(srcShape[i]);
1312 auto mixedTiles = unpackOp.getMixedTiles();
1313 extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
1314 shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());
1318 auto tileShape = srcShape.drop_front(destRank);
1320 readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
1321 Type elemType = unpackOp.getSourceType().getElementType();
1323 Value innerTile = tensor::ExtractSliceOp::create(
1324 rewriter, loc, readType, unpackOp.getSource(), extractSliceOffsets,
1325 extractSliceSizes, extractSliceStrides);
1329 srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
1332 applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);
1335 tensor::EmptyOp::create(rewriter, loc, shapeForEmptyOp, elemType);
1337 linalg::TransposeOp::create(rewriter, loc, innerTile, empty, perm);
1341 int numLoops = shapeForEmptyOp.size();
1346 for (auto i : llvm::seq<unsigned>(0, destRank)) {
1347 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1348 tileSizes.push_back(
1353 tensor::ExtractSliceOp::create(rewriter, loc, transposedOp.getResult()[0],
1354 tileOffsets, tileSizes, tileStrides);
1360 for (int i = 0, idx = 0; i < destRank; ++i) {
1361 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1362 writeSizes.push_back(tileSizes[idx++]);
1364 writeSizes.push_back(oneIdxAttr);
1366 auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile,
1367 unpackOp.getDest(), writeOffsets,
1368 writeSizes, writeStrides);
1369 rewriter.replaceOp(unpackOp, insert.getResult());
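// DownscaleSizeOneWindowed2DConvolution: rewrite 2-D convolution/pooling ops
// with a size-1 window dimension (kh == 1 with oh == 1, or kw == 1 with
// ow == 1) into their 1-D equivalents by dropping the unit H or W dimension.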
1382 template <typename Conv2DOp, typename Conv1DOp>
1385 if (convOp.hasPureBufferSemantics())
1388 Value input = convOp.getInputs().front();
1389 Value kernel = convOp.getInputs().back();
1390 Value output = convOp.getOutputs().front();
1392 auto inputType = dyn_cast<RankedTensorType>(input.getType());
1393 auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1394 auto outputType = dyn_cast<RankedTensorType>(output.getType());
1396 auto kernelShape = kernelType.getShape();
1397 auto outputShape = outputType.getShape();
1400 auto [khIndex, kwIndex, ohIndex, owIndex] =
1403 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
1404 return std::make_tuple(0, 1, 1, 2);
1406 .Case([&](linalg::Conv2DNchwFchwOp op) {
1407 return std::make_tuple(2, 3, 2, 3);
1409 .Case([&](linalg::PoolingNhwcSumOp op) {
1410 return std::make_tuple(0, 1, 1, 2);
1412 .Case([&](linalg::PoolingNchwSumOp op) {
1413 return std::make_tuple(0, 1, 2, 3);
1415 .Case([&](linalg::PoolingNhwcMaxOp op) {
1416 return std::make_tuple(0, 1, 1, 2);
1418 .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
1419 return std::make_tuple(0, 1, 1, 2);
1421 .Case([&](linalg::PoolingNhwcMinOp op) {
1422 return std::make_tuple(0, 1, 1, 2);
1424 .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
1425 return std::make_tuple(0, 1, 1, 2);
1427 .Case([&](linalg::PoolingNchwMaxOp op) {
1428 return std::make_tuple(0, 1, 2, 3);
1431 llvm_unreachable("unexpected conv2d/pool2d operation.");
1432 return std::make_tuple(0, 0, 0, 0);
1437 int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
1438 int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
1439 bool removeH = (khSize == 1 && ohSize == 1);
1440 bool removeW = (kwSize == 1 && owSize == 1);
1441 if (!removeH && !removeW)
1447 RankedTensorType newInputType =
1448 RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
1449 RankedTensorType newKernelType =
1450 RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
1451 RankedTensorType newOutputType =
1452 RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));
1457 rewriter, loc, input, newInputType);
1459 rewriter, loc, kernel, newKernelType);
1461 rewriter, loc, output, newOutputType);
1466 llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
1467 strides.erase(strides.begin() + (removeH ? 0 : 1));
1471 llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
1472 dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1475 auto conv1DOp = Conv1DOp::create(
1476 rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
1477 ValueRange{newOutput}, stridesAttr, dilationsAttr);
1481 rewriter, loc, conv1DOp.getResult(0), output);
1498 PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
1502 PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
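// DownscaleDepthwiseConv2DNhwcHwcOp: the same downscaling for 2-D depthwise
// convolutions with size-1 (h, kh) or (w, kw) dimensions, producing a
// depthwise_conv_1d_nwc_wc op.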
1506 FailureOr<DepthwiseConv1DNwcWcOp>
1509 if (convOp.hasPureBufferSemantics())
1512 Value input = convOp.getInputs().front();
1513 Value kernel = convOp.getInputs().back();
1514 Value output = convOp.getOutputs().front();
1516 auto inputType = dyn_cast<RankedTensorType>(input.getType());
1517 auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1518 auto outputType = dyn_cast<RankedTensorType>(output.getType());
1520 auto kernelShape = kernelType.getShape();
1521 auto outputShape = outputType.getShape();
1525 int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1526 int64_t ohSize = outputShape[1], owSize = outputShape[2];
1527 bool removeH = (khSize == 1 && ohSize == 1);
1528 bool removeW = (kwSize == 1 && owSize == 1);
1529 if (!removeH && !removeW)
1535 RankedTensorType newInputType =
1536 RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
1537 RankedTensorType newKernelType =
1538 RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1539 RankedTensorType newOutputType =
1540 RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
1545 rewriter, loc, input, newInputType);
1547 rewriter, loc, kernel, newKernelType);
1549 rewriter, loc, output, newOutputType);
1553 auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
1554 strides.erase(strides.begin() + (removeH ? 0 : 1));
1558 llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
1559 dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1562 auto conv1DOp = DepthwiseConv1DNwcWcOp::create(
1563 rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
1564 ValueRange{newOutput}, stridesAttr, dilationsAttr);
1568 rewriter, loc, conv1DOp.getResult(0), output);
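// DownscaleConv2DOp (upstream name): the unbatched variant; drop the unit H
// or W dimension of a conv_2d to form a conv_1d.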
1577 if (convOp.hasPureBufferSemantics())
1580 Value input = convOp.getInputs().front();
1581 Value kernel = convOp.getInputs().back();
1582 Value output = convOp.getOutputs().front();
1584 auto inputType = dyn_cast<RankedTensorType>(input.getType());
1585 auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1586 auto outputType = dyn_cast<RankedTensorType>(output.getType());
1588 auto kernelShape = kernelType.getShape();
1589 auto outputShape = outputType.getShape();
1593 int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1594 int64_t ohSize = outputShape[0], owSize = outputShape[1];
1595 bool removeH = (khSize == 1 && ohSize == 1);
1596 bool removeW = (kwSize == 1 && owSize == 1);
1597 if (!removeH && !removeW)
1603 RankedTensorType newInputType =
1604 RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
1605 RankedTensorType newKernelType =
1606 RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1607 RankedTensorType newOutputType =
1608 RTTBuilder(outputType).dropDim(removeH ? 0 : 1);
1613 rewriter, loc, input, newInputType);
1615 rewriter, loc, kernel, newKernelType);
1617 rewriter, loc, output, newOutputType);
1620 Conv1DOp::create(rewriter, loc, newOutputType,
1625 rewriter, loc, conv1DOp.getResult(0), output);
1644 PoolingNwcMaxUnsignedOp>,
1647 PoolingNwcMinUnsignedOp>,
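// populateDecomposeConvolutionPatterns: registers the downscaling patterns
// above (2-D convolutions/poolings with unit window dims -> 1-D equivalents).
//
// A minimal usage sketch (assumes a `funcOp` and the greedy rewrite driver;
// older MLIR spells the driver applyPatternsAndFoldGreedily):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateDecomposeConvolutionPatterns(patterns);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();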