32#include "llvm/ADT/SmallVectorExtras.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/Support/Debug.h"
35#include "llvm/Support/DebugLog.h"
36#include "llvm/Support/InterleavedRange.h"
37#include "llvm/Support/raw_ostream.h"
41#define DEBUG_TYPE "linalg-transforms"
60 .Case([&](scf::ForOp forOp) {
61 scf::ForOp partialIteration;
64 return partialIteration->getResults();
65 assert(!partialIteration &&
"expected that loop was not peeled");
66 return forOp->getResults();
75 for (
auto loopOp : loops)
88 if (!e.isFunctionOfDim(dim))
99 return llvm::interleaved(ri,
", ",
"|",
"");
150static FailureOr<SmallVector<std::optional<int64_t>>>
154 int64_t newDim = iteratorTypes.size();
155 iteratorTypes.push_back(iteratorTypes[dim]);
158 indexingMaps.size(), std::nullopt);
160 for (
int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
162 AffineMap map = indexingMaps[operandIdx];
165 assert(map.
getNumDims() == newDim &&
"num dims invariant violation");
173 "num results invariant violation");
175 if (!maybeOperandDimensionToPack.has_value()) {
176 newMaps.push_back(map);
181 if (!isa<AffineDimExpr>(map.
getResult(maybeOperandDimensionToPack.value())))
187 newMaps.push_back(map);
190 packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
192 indexingMaps = newMaps;
194 return packedDimPerIndexingMap;
200struct PackedOperandsDim {
201 OpFoldResult packedSize;
202 SmallVector<std::optional<int64_t>> packedDimForEachOperand;
206struct PackedOperandsDimList {
207 void pushBack(PackedOperandsDim &&packedOperandsDims) {
208 spec.emplace_back(packedOperandsDims);
211 SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
213 SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);
216 SmallVector<PackedOperandsDim> spec;
222 linalg::PackOp packOp,
223 bool lowerPadLikeWithInsertSlice) {
225 if (!packOp.hasPureTensorSemantics())
229 auto packedTensorType =
230 cast<RankedTensorType>(packOp->getResultTypes().front());
231 if (llvm::any_of(packOp.getStaticInnerTiles(), ShapedType::isDynamic)) {
234 "non-static shape NYI, needs a more powerful tensor.expand_shape op");
243 PackingMetadata packingMetadata;
257 for (
auto [pos, innerSize] :
258 llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
260 packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
270 rewriter, loc, map, {outerSize, origSize, innerSize});
272 RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
274 packingMetadata.reassociations);
275 Value paddingValue = packOp.getPaddingValue();
277 paddingValue = arith::ConstantOp::create(
281 tensor::PadOp::create(rewriter, loc, collapsed, packOp.getSource(), lows,
282 highs, paddingValue,
false);
284 LDBG() <<
"insertPositions: "
285 << llvm::interleaved(packingMetadata.insertPositions);
286 LDBG() <<
"outerPositions: "
287 << llvm::interleaved(packingMetadata.outerPositions);
288 LDBG() <<
"packedShape: " << llvm::interleaved(packedTensorType.getShape());
289 LDBG() <<
"packedToStripMinedShapePerm: "
290 << llvm::interleaved(packedToStripMinedShapePerm);
291 LDBG() <<
"reassociations: "
292 << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
294 LDBG() <<
"stripMinedShape: " << llvm::interleaved(stripMinedShape);
295 LDBG() <<
"collapsed type: " << collapsed;
297 if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
316 auto insertSliceOp = tensor::InsertSliceOp::create(
317 rewriter, loc, padOp, packOp.getDest(),
320 LDBG() <<
"insert_slice op: " << insertSliceOp;
322 rewriter.
replaceOp(packOp, insertSliceOp->getResults());
330 auto expandShapeResultType =
332 auto reshapeOp = tensor::ExpandShapeOp::create(
333 rewriter, loc, expandShapeResultType, padOp.getResult(),
334 packingMetadata.reassociations);
339 auto transposeOp = linalg::TransposeOp::create(
340 rewriter, loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
342 LDBG() <<
"reshape op: " << reshapeOp;
343 LDBG() <<
"transpPerm: " << llvm::interleaved(transpPerm);
344 LDBG() <<
"transpose op: " << transposeOp;
347 rewriter.
replaceOp(packOp, transposeOp->getResults());
352FailureOr<LowerUnPackOpResult>
354 bool lowerUnpadLikeWithExtractSlice) {
356 if (!unPackOp.hasPureTensorSemantics())
363 auto packedTensorType = cast<RankedTensorType>(unPackOp.getSourceType());
364 int64_t packedRank = packedTensorType.getRank();
367 auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
368 if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
377 auto extractSliceOp = tensor::ExtractSliceOp::create(
378 rewriter, loc, destTensorType, unPackOp.getSource(),
382 rewriter.
replaceOp(unPackOp, extractSliceOp->getResults());
385 nullptr, extractSliceOp,
391 PackingMetadata packingMetadata;
401 RankedTensorType stripMinedTensorType =
403 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
404 stripMinedTensorType, packingMetadata.reassociations);
411 auto emptyOp = tensor::EmptyOp::create(rewriter, loc, dims,
412 stripMinedTensorType.getElementType());
414 linalg::TransposeOp::create(rewriter, loc, unPackOp.getSource(), emptyOp,
415 packedToStripMinedShapePerm);
417 LDBG() <<
"insertPositions: "
418 << llvm::interleaved(packingMetadata.insertPositions);
419 LDBG() <<
"packedShape: " << llvm::interleaved(packedTensorType.getShape());
420 LDBG() <<
"packedToStripMinedShapePerm: "
421 << llvm::interleaved(packedToStripMinedShapePerm);
422 LDBG() <<
"reassociations: "
423 << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
425 LDBG() <<
"stripMinedShape: " << llvm::interleaved(stripMinedShape);
426 LDBG() <<
"collapsed type: " << collapsedType;
429 auto reshapeOp = tensor::CollapseShapeOp::create(
430 rewriter, loc, collapsedType, transposeOp->getResult(0),
431 packingMetadata.reassociations);
434 int64_t destRank = destTensorType.getRank();
435 auto extractSliceOp = tensor::ExtractSliceOp::create(
436 rewriter, loc, destTensorType, reshapeOp->getResult(0),
442 auto copyOp = linalg::CopyOp::create(
443 rewriter, loc, extractSliceOp->getResult(0), unPackOp.getDest());
446 rewriter.
replaceOp(unPackOp, copyOp->getResults());
453PackedOperandsDimList::extractPackedDimsForOperand(
int64_t operandPos) {
455 for (
auto &i : spec) {
456 if (!i.packedDimForEachOperand[operandPos].has_value())
458 res.push_back(i.packedDimForEachOperand[operandPos].value());
463SmallVector<OpFoldResult>
464PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
465 SmallVector<OpFoldResult> res;
466 for (
auto &i : spec) {
467 if (!i.packedDimForEachOperand[operandPos].has_value())
469 res.push_back(i.packedSize);
478 linalg::LinalgOp linalgOp,
480 if (packedSizes.size() != linalgOp.getNumLoops()) {
482 "incorrect number of pack sizes");
488 linalgOp.getIteratorTypesArray();
489 LDBG() <<
"Start packing: " << linalgOp;
490 LDBG() <<
"maps: " << llvm::interleaved(indexingMaps);
491 LDBG() <<
"iterators: " << llvm::interleaved(iteratorTypes);
496 PackedOperandsDimList listOfPackedOperandsDim;
497 for (
int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
500 if (maybeConstant.has_value() && maybeConstant.value() == 0)
503 PackedOperandsDim packedOperandsDims;
504 packedOperandsDims.packedSize = packedSizes[i];
505 FailureOr<SmallVector<std::optional<int64_t>>>
506 maybePackedDimForEachOperand =
508 if (failed(maybePackedDimForEachOperand))
510 packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
512 LDBG() <<
"++++ After pack size #" << i <<
": " << packedSizes[i];
513 LDBG() <<
"maps: " << llvm::interleaved(indexingMaps);
514 LDBG() <<
"iterators: " << llvm::interleaved(iteratorTypes);
515 LDBG() <<
"packedDimForEachOperand: "
516 << llvm::interleaved(packedOperandsDims.packedDimForEachOperand);
518 listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
524 llvm::to_vector(llvm::make_pointer_range(linalgOp.getDpsInitsMutable()));
526 for (
const auto &operandsList : {inputOperands, initOperands}) {
527 for (
OpOperand *opOperand : operandsList) {
528 int64_t pos = opOperand->getOperandNumber();
529 Value operand = opOperand->get();
531 listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
533 listOfPackedOperandsDim.extractPackSizesForOperand(pos);
534 LDBG() <<
"operand: " << operand;
535 LDBG() <<
"innerPos: " << llvm::interleaved(innerPos);
536 LDBG() <<
"innerPackSizes: " << llvm::interleaved(innerPackSizes);
537 if (innerPackSizes.empty()) {
538 inputsAndInits.push_back(operand);
541 Value dest = linalg::PackOp::createDestinationTensor(
542 rewriter, loc, operand, innerPackSizes, innerPos,
544 ShapedType operandType = cast<ShapedType>(operand.
getType());
545 bool areConstantTiles =
549 if (areConstantTiles && operandType.hasStaticShape() &&
550 !linalg::PackOp::requirePaddingValue(
551 operandType.getShape(), innerPos,
552 cast<ShapedType>(dest.
getType()).getShape(), {},
554 packOps.push_back(linalg::PackOp::create(rewriter, loc, operand, dest,
555 innerPos, innerPackSizes));
561 Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
562 packOps.push_back(linalg::PackOp::create(
563 rewriter, loc, operand, dest, innerPos, innerPackSizes, zero));
565 inputsAndInits.push_back(packOps.back().getResult());
571 ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
573 ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
574 auto packedLinalgOp =
575 linalg::GenericOp::create(rewriter, linalgOp.getLoc(), inits.
getTypes(),
576 inputs, inits, indexingMaps, iteratorTypes);
577 packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));
582 linalg::PackOp maybePackedInit =
583 inits[resultNum].getDefiningOp<linalg::PackOp>();
584 if (!maybePackedInit) {
585 results.push_back(
result);
589 unPackOps.push_back(linalg::UnPackOp::create(
590 rewriter, packedLinalgOp->getLoc(),
result, maybePackedInit.getSource(),
591 maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
592 results.push_back(unPackOps.back().getResult());
600 cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
629 assert(linalgOp == opOperand.
getOwner() &&
"linalg op must own the operand");
633 cast<RankedTensorType>(opOperand.
get().
getType()), permutation);
635 assert(tensorType == transposedValue.
getType() &&
636 "expected tensor type mismatch");
641 llvm::map_to_vector(permutation, [](
int64_t i) ->
unsigned {
return i; });
645 permutationMap.
compose(linalgOp.getMatchingIndexingMap(&opOperand));
649 indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
655 auto transposedGenericOp = linalg::GenericOp::create(
659 operandsRef.drop_front(linalgOp.getNumDpsInputs()).
getTypes(),
660 operandsRef.take_front(linalgOp.getNumDpsInputs()),
661 operandsRef.drop_front(linalgOp.getNumDpsInputs()),
663 linalgOp.getIteratorTypesArray());
664 transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
665 rewriter.
replaceOp(linalgOp, transposedGenericOp->getResults());
667 return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
670FailureOr<PackTransposeResult>
672 linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp,
679 linalg::PackOp transposedPackOp =
680 packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);
682 if (packOp.hasPureBufferSemantics() || !packOp.getResult().hasOneUse())
685 OpOperand &packUse = *packOp->getUses().begin();
686 if (packUse.
getOwner() != linalgOp) {
688 linalgOp,
"not a single use by the LinalgOp target");
691 (!linalgOp.isDpsInit(&packUse) ||
692 maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
694 "not produced by the LinalgOp target");
700 int64_t numLeadingDims = packOp.getSourceRank();
701 int64_t numTrailingDims = packOp.getInnerDimsPos().size();
705 if (permutation.empty())
706 llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
708 if (innerPerm.empty()) {
711 llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
713 llvm::append_range(permutation,
714 llvm::map_range(innerPerm, [&](
int64_t pos) {
715 return numLeadingDims + pos;
727 rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
730 linalg::UnPackOp transposedUnPackOp;
733 transposedLinalgOp->getOpOperand(packUseOperandNumber);
734 OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
736 transposedUnPackOp = maybeUnPackOp.createTransposedClone(
737 rewriter, loc, transposedResult, innerPerm, outerPerm);
739 rewriter.
replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
743 if (packOp.hasPureTensorSemantics())
744 rewriter.
replaceOp(packOp, transposedPackOp->getResults());
769 assert(mnkPackedSizes.size() == 3 &&
"unexpected num of packing sizes");
770 assert((mnkPaddedSizesNextMultipleOf.empty() ||
771 mnkPaddedSizesNextMultipleOf.size() == 3) &&
772 "num of packing sizes next multiple should be empty or of size 3");
773 assert(mnkOrder.size() == 3 &&
"unexpected mnkOrder size");
776 int64_t numLoops = linalgOp.getNumLoops();
778 LDBG() <<
"need 3+ loops to find a matmul to pack, got " << numLoops
779 <<
" in: " << linalgOp;
781 linalgOp,
"need 3+ loops to find a matmul to pack");
785 int64_t numPackedDims = mnkPackedSizes.size();
787 for (
int64_t i = 0, e = numPackedDims; i < e; ++i)
788 mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
790 for (
int64_t i = 0, e = numPackedDims; i < e; ++i)
791 packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
793 for (
int64_t i = 0, e = numPackedDims; i < e; ++i) {
794 paddedSizesNextMultipleOf[mnkOrder[i]] =
795 mnkPaddedSizesNextMultipleOf.empty() ? 0
796 : mnkPaddedSizesNextMultipleOf[i];
800 FailureOr<ContractionDimensions> maybeDimensions =
802 if (failed(maybeDimensions)) {
803 LDBG() <<
"couldn't infer matmul iterators in: " << linalgOp;
805 "couldn't infer matmul iterators");
813 int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
814 kPos = maybeDimensions->k.back();
815 LDBG() <<
"Start packing generic op greedily with (m@" << mPos <<
", n@"
816 << nPos <<
", k@" << kPos <<
"): " << linalgOp;
819 auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
821 FailureOr<GenericOp> generalizeResult =
823 assert(succeeded(generalizeResult) &&
"unexpected failure generalizing op");
824 genericOp = *generalizeResult;
832 LDBG() <<
"perm: " << llvm::interleaved(permutation);
835 FailureOr<GenericOp> interchangeResult =
837 assert(succeeded(interchangeResult) &&
"unexpected failure interchanging op");
838 genericOp = *interchangeResult;
839 LDBG() <<
"Generalized Op to pack: " << genericOp;
856 cast<LinalgOp>(genericOp.getOperation())
857 .createLoopRanges(rewriter, genericOp.getLoc());
861 LDBG() <<
"paddedSizesNextMultipleOf: "
862 << llvm::interleaved(paddedSizesNextMultipleOf);
863 LDBG() <<
"loopRanges: "
864 << llvm::interleaved(
865 llvm::map_range(loopRanges, [](
Range r) {
return r.
size; }));
868 for (
int64_t i = 0, e = numPackedDims; i < e; ++i) {
869 if (paddedSizesNextMultipleOf[i] == 0) {
870 adjustedPackedSizes.push_back(packedSizes[i]);
877 rewriter, genericOp->getLoc(), d0.
ceilDiv(s0) * s0,
878 {loopRanges[adjustedPackedSizes.size()].size,
879 rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
881 LDBG() <<
"adjustedPackedSizes: " << llvm::interleaved(adjustedPackedSizes);
887 return pack(rewriter, genericOp, adjustedPackedSizes);
900 b.setInsertionPointToStart(
901 &op->getParentOfType<func::FuncOp>().getBody().front());
902 return llvm::map_to_vector<4>(tileSizes, [&](
int64_t s) {
920 auto padValue = padOp.getConstantPaddingValue();
923 if (padValue.getParentBlock() == &padOp.getRegion().front())
925 return FillOp::create(rewriter, padOp.getLoc(), padValue, dest).result();
929 auto generateOp = tensor::GenerateOp::create(rewriter, padOp.getLoc(),
930 padOp.getResultType(), dynSizes);
933 padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
942 if (
auto val = llvm::dyn_cast_if_present<Value>(ofr))
945 rewriter, padOp.getLoc(),
946 cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
950 auto resultType = padOp.getResultType();
954 for (
unsigned dim = 0; dim < resultType.getRank(); ++dim) {
955 if (resultType.isDynamicDim(dim)) {
957 padOp.getSource(), dim));
960 padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
962 padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
963 dynSizes.push_back(plusHigh);
965 staticSizes.push_back(resultType.getDimSize(dim));
970 tensor::EmptyOp::create(rewriter, padOp.getLoc(), staticSizes,
971 resultType.getElementType(), dynSizes);
975 auto sourceType = padOp.getSourceType();
983 padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
991 if (!sliceOp.hasUnitStride())
994 auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
998 bool zeroSliceGuard =
true;
1000 if (std::optional<bool> control = controlFn(sliceOp))
1001 zeroSliceGuard = *control;
1006 FailureOr<TilingResult> tilingResult =
1008 sliceOp.getMixedSizes(), zeroSliceGuard);
1009 if (failed(tilingResult))
1012 RankedTensorType sourceType = sliceOp.getSourceType();
1013 RankedTensorType resultType = sliceOp.getResultType();
1017 if (sourceType.getRank() == resultType.getRank()) {
1018 rewriter.
replaceOp(sliceOp, tilingResult->tiledValues);
1024 rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);
1026 rewriter.
replaceOp(sliceOp, rankReduced);
1036 linalg::PackOp packOp) {
1037 Value input = packOp.getSource();
1039 if (!packOp.hasPureTensorSemantics())
1042 if (!packOp.getPaddingValue()) {
1046 assert(llvm::all_of(packOp.getAllOuterDims(),
1047 [](
int64_t val) { return val == 1; }) &&
1048 "some outer dims are != 1");
1051 ShapedType inputType = packOp.getSourceType();
1052 int64_t inputRank = inputType.getRank();
1055 packOp.getDimAndTileMapping();
1062 for (
int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
1065 if (!tileAndPosMapping.count(dimIdx)) {
1066 int64_t inputDimSize = inputType.getDimSize(dimIdx);
1067 assert(inputDimSize == 1 &&
1068 "with all outer dims == 1, this non-tiled input dim should be 1!");
1069 paddedShape.push_back(inputDimSize);
1076 OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
1080 if (cstTileSize.has_value()) {
1081 paddedShape.push_back(cstTileSize.value());
1086 paddedShape.push_back(ShapedType::kDynamic);
1089 dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
1092 RankedTensorType::get(paddedShape, inputType.getElementType());
1094 false, loc, builder,
1102static SmallVector<int64_t>
1104 constexpr int64_t kNonTiledMarker = -1;
1106 for (
auto [
index, value] : llvm::enumerate(perm))
1109 vec, [&](
int64_t v) {
return v != kNonTiledMarker; });
1116static SmallVector<int64_t>
1125 for (
auto i : llvm::seq<unsigned>(0, unpackedRank)) {
1126 if (llvm::is_contained(innerDimsPos, i)) {
1127 innerDims.push_back(dim++);
1132 outerDims.push_back(dim++);
1133 if (!outerDimsPerm.empty())
1134 rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
1145 rankReducedOuterDimsPerm =
1147 if (!rankReducedOuterDimsPerm.empty())
1151 perm.append(innerDims);
1159 if (!packOp.hasPureTensorSemantics())
1162 if (llvm::any_of(packOp.getTiledOuterDims(),
1163 [](
int64_t dim) { return dim != 1; })) {
1165 packOp,
"not all outer dimensions of the result are 1s");
1169 auto outerDimsPerm = packOp.getOuterDimsPerm();
1175 if (!llvm::all_of(outerDimsPerm, [&innerDimsPos, &packOp](
int64_t dim) {
1176 static int prev = 0;
1178 if (llvm::is_contained(innerDimsPos, dim))
1183 if (dim < prev && (packOp.getResult().getType().getShape()[prev] != 1 ||
1184 packOp.getResult().getType().getShape()[dim] != 1))
1191 packOp,
"At least one non-unit and un-tiled outer dim is permuted, "
1192 "this is not supported ATM!");
1197 int64_t srcRank = packOp.getSourceRank();
1216 for (
int64_t i = 0; i < srcRank; i++) {
1224 if (llvm::is_contained(innerDimsPos, i))
1226 srcPermForTranspose.push_back(i);
1228 srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());
1232 ShapedType inputTy = cast<ShapedType>(input.
getType());
1234 for (
int64_t i = 0; i < srcRank; i++) {
1235 if (llvm::is_contained(innerDimsPos, i)) {
1239 if (inputTy.isStaticDim(i))
1240 shapeForEmptyOp.push_back(rewriter.
getIndexAttr(inputTy.getShape()[i]));
1242 shapeForEmptyOp.emplace_back(
1243 tensor::DimOp::create(rewriter, loc, input, i).getResult());
1245 shapeForEmptyOp.append(packOp.getMixedTiles());
1252 llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(),
1254 if (auto val = llvm::dyn_cast<Value>(ofr))
1255 return getAsOpFoldResult(val);
1259 LDBG() <<
"Pack permutation: " << packOp;
1260 LDBG() <<
"perm: " << llvm::interleaved(srcPermForTranspose);
1261 LDBG() <<
"Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp);
1263 Value empty = tensor::EmptyOp::create(
1264 rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType());
1267 auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
1268 srcPermForTranspose);
1280 for (
auto size : packOp.getAllOuterDims()) {
1284 for (
auto tileSize : packOp.getMixedTiles()) {
1285 auto [_, tileSizeOfr] =
1287 writeSizes.push_back(tileSizeOfr);
1290 auto insert = tensor::InsertSliceOp::create(
1291 rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeSizes);
1294 rewriter.
replaceOp(packOp, insert.getResult());
1301 if (!unpackOp.hasPureTensorSemantics())
1304 int64_t destRank = unpackOp.getDestRank();
1307 if (llvm::any_of(unpackOp.getTiledOuterDims(),
1308 [](
int64_t dim) { return dim != 1; })) {
1311 "require the tiled outer dimensions of the result are all 1s");
1317 Value source = unpackOp.getSource();
1319 unpackOp.getDimAndTileMapping();
1338 for (
auto i : llvm::seq<unsigned>(0, destRank)) {
1347 if (dimAndTileMapping.count(i)) {
1348 extractSliceSizes.push_back(oneIdxAttr);
1354 if (ShapedType::isDynamic(srcShape[i])) {
1356 tensor::DimOp::create(rewriter, loc, source, i).getResult();
1357 extractSliceSizes.push_back(dynamicDim);
1358 shapeForEmptyOp.push_back(dynamicDim);
1360 extractSliceSizes.push_back(rewriter.
getIndexAttr(srcShape[i]));
1361 if (srcShape[i] != 1)
1362 shapeForEmptyOp.push_back(rewriter.
getIndexAttr(srcShape[i]));
1366 if (srcShape[i] != 1) {
1367 readShapeForExtractSlice.push_back(srcShape[i]);
1372 auto mixedTiles = unpackOp.getMixedTiles();
1373 extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
1374 shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());
1378 auto tileShape = srcShape.drop_front(destRank);
1380 readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
1381 Type elemType = unpackOp.getSourceType().getElementType();
1382 auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
1383 Value innerTile = tensor::ExtractSliceOp::create(
1384 rewriter, loc, readType, unpackOp.getSource(), extractSliceSizes);
1388 srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
1394 tensor::EmptyOp::create(rewriter, loc, shapeForEmptyOp, elemType);
1396 linalg::TransposeOp::create(rewriter, loc, innerTile, empty, perm);
1402 for (
auto i : llvm::seq<unsigned>(0, destRank)) {
1403 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1404 tileSizes.push_back(
1409 tensor::ExtractSliceOp::create(rewriter, loc, RankedTensorType(),
1410 transposedOp.getResult()[0], tileSizes);
1414 for (
int i = 0, idx = 0; i < destRank; ++i) {
1415 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1416 writeSizes.push_back(tileSizes[idx++]);
1418 writeSizes.push_back(oneIdxAttr);
1420 auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile,
1421 unpackOp.getDest(), writeSizes);
1422 rewriter.
replaceOp(unpackOp, insert.getResult());
1436 for (
unsigned dim : dims) {
1440 resultIndices.push_back(i);
1445 return resultIndices;
1453 auto tensorType = cast<RankedTensorType>(
tensor.getType());
1454 int64_t rank = tensorType.getRank();
1458 for (
int64_t i = 0; i < rank; ++i) {
1459 if (!llvm::is_contained(dimsToRemove, i))
1460 newShape.push_back(tensorType.getDimSize(i));
1463 auto newType = RankedTensorType::get(newShape, tensorType.getElementType());
1471static std::optional<AffineExpr>
1475 bool onlyReferencesDroppedDims =
true;
1476 for (
unsigned d = 0; d < newNumDims + dimsToDrop.size(); ++d) {
1478 onlyReferencesDroppedDims =
false;
1482 if (onlyReferencesDroppedDims && llvm::any_of(dimsToDrop, [&](
unsigned d) {
1485 return std::nullopt;
1490 unsigned newDimIdx = 0;
1491 for (
unsigned d = 0; d < newNumDims + dimsToDrop.size(); ++d) {
1492 if (llvm::is_contained(dimsToDrop, d)) {
1506 if (failed(maybeDims))
1510 if (maybeDims->outputImage.size() != 2 || maybeDims->filterLoop.size() != 2)
1513 if (op.hasPureBufferSemantics())
1517 unsigned outSpatial0 = maybeDims->outputImage[0];
1518 unsigned outSpatial1 = maybeDims->outputImage[1];
1519 unsigned filterSpatial0 = maybeDims->filterLoop[0];
1520 unsigned filterSpatial1 = maybeDims->filterLoop[1];
1524 int64_t outSize0 = loopRanges[outSpatial0];
1525 int64_t outSize1 = loopRanges[outSpatial1];
1526 int64_t filterSize0 = loopRanges[filterSpatial0];
1527 int64_t filterSize1 = loopRanges[filterSpatial1];
1530 bool canRemoveSpatial0 = (filterSize0 == 1 && outSize0 == 1);
1531 bool canRemoveSpatial1 = (filterSize1 == 1 && outSize1 == 1);
1532 if (!canRemoveSpatial0 && !canRemoveSpatial1)
1539 if (canRemoveSpatial0) {
1540 loopDimsToRemove.push_back(outSpatial0);
1541 loopDimsToRemove.push_back(filterSpatial0);
1543 loopDimsToRemove.push_back(outSpatial1);
1544 loopDimsToRemove.push_back(filterSpatial1);
1546 llvm::sort(loopDimsToRemove);
1551 unsigned numDims = op.getNumLoops();
1552 unsigned newNumDims = numDims - loopDimsToRemove.size();
1553 for (
AffineMap map : op.getIndexingMapsArray()) {
1559 newResults.push_back(*newExpr);
1561 newMaps.push_back(
AffineMap::get(newNumDims, 0, newResults, ctx));
1566 auto iterTypes = op.getIteratorTypesArray();
1567 for (
unsigned idx = 0; idx < iterTypes.size(); ++idx) {
1568 if (!llvm::is_contained(loopDimsToRemove, idx))
1569 newIterTypes.push_back(iterTypes[idx]);
1575 for (
OpOperand *input : op.getDpsInputOperands()) {
1576 AffineMap map = op.getMatchingIndexingMap(input);
1580 tensorDimsToRemove);
1581 newInputs.push_back(reduced);
1584 OpOperand &output = *op.getDpsInitsMutable().begin();
1585 AffineMap outputMap = op.getMatchingIndexingMap(&output);
1589 outputDimsToRemove);
1594 newInputs, newOutput, newMaps, newIterTypes);
1596 newOp.getRegion().begin());
1600 LinalgOp resultOp = newOp;
1601 if (!isa<GenericOp>(op)) {
1603 if (succeeded(specializedOp))
1604 resultOp = *specializedOp;
1609 rewriter, loc, resultOp->getResult(0), output.
get());
1617struct DownscaleSizeOneWindowedConvolution final
1619 DownscaleSizeOneWindowedConvolution(
MLIRContext *context,
1623 LogicalResult matchAndRewrite(LinalgOp op,
1632 patterns.
add<DownscaleSizeOneWindowedConvolution>(patterns.
getContext(),
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr replaceDims(ArrayRef< AffineExpr > dimReplacements) const
Dim-only version of replaceDimsAndSymbols.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
MLIRContext * getContext() const
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ... numDims) by dims[offset + shift ... shift + numDims).
AffineMap insertResult(AffineExpr expr, unsigned pos) const
Returns a new AffineMap with the same number of dims and symbols and an extra result inserted at pos.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
TypedAttr getZeroAttr(Type type)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp, PackingMetadata &metadata)
Compute inverse permutation for the source tensor (i.e.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite pack as empty + transpose + reshape + extract_slice + copy.
void peelLoops(RewriterBase &rewriter, ArrayRef< scf::ForOp > loops)
Peel 'loops' and applies affine_min/max bounds simplification on the fly where relevant.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp, const GenericOpSpecializationOptions &options={})
Replace the given GenericOp with a namedOp or categoryOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Linalg decompose convolutions patterns.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp, PackingMetadata &metadata)
Compute inverse permutation for the destination tensor (i.e.
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
SmallVector< Value > peelLoop(RewriterBase &rewriter, Operation *op)
Try to peel and canonicalize loop op and return the new result.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
FailureOr< LinalgOp > downscaleSizeOneWindowedConvolution(RewriterBase &rewriter, LinalgOp op)
Rewrite convolution/pooling/depthwise ops with size-1 window dimensions into lower-dimensional ops.
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
FailureOr< TilingResult > bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, bool generateZeroSliceGuard=true)
Bubbles up a slice of this pad by taking the slice first and then performing the padding.
PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad, bool nofold, Location loc, OpBuilder &builder, ValueRange dynOutDims={})
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
ArrayRef< int64_t > ReassociationIndicesRef
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< int64_t > computePermutationVector(int64_t permSize, ArrayRef< int64_t > positions, ArrayRef< int64_t > desiredPositions)
Return a permutation vector of size permSize that would result in moving positions into desiredPositi...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
std::pair< int64_t, OpFoldResult > getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b)
Given OpFoldResult representing dim size value (*), generates a pair of sizes:
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter &rewriter) const override
Rewrites a linalg::PackOp into a sequence of:
LogicalResult matchAndRewrite(linalg::PackOp packOp, PatternRewriter &rewriter) const override
Rewrites a linalg::UnPackOp into a sequence of:
LogicalResult matchAndRewrite(linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const override
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const override
Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp, Value dest, const SmallVector< Value > &dynSizes) const
Filling dest using FillOp constant padding value if possible.
LinalgTilingOptions & setTileSizes(const SmallVector< Value, 4 > &ts)
Set the tileSizeComputationFunction to return the values ts.
TileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes for each operation.
Struct to hold the result of a pack call.
Struct to hold the result of a packTranspose call.