32#include "llvm/ADT/SmallVectorExtras.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/Support/Debug.h"
35#include "llvm/Support/DebugLog.h"
36#include "llvm/Support/InterleavedRange.h"
37#include "llvm/Support/raw_ostream.h"
41#define DEBUG_TYPE "linalg-transforms"
60 .Case([&](scf::ForOp forOp) {
61 scf::ForOp partialIteration;
64 return partialIteration->getResults();
65 assert(!partialIteration &&
"expected that loop was not peeled");
66 return forOp->getResults();
75 for (
auto loopOp : loops)
88 if (!e.isFunctionOfDim(dim))
99 return llvm::interleaved(ri,
", ",
"|",
"");
150static FailureOr<SmallVector<std::optional<int64_t>>>
154 int64_t newDim = iteratorTypes.size();
155 iteratorTypes.push_back(iteratorTypes[dim]);
158 indexingMaps.size(), std::nullopt);
160 for (
int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
162 AffineMap map = indexingMaps[operandIdx];
165 assert(map.
getNumDims() == newDim &&
"num dims invariant violation");
173 "num results invariant violation");
175 if (!maybeOperandDimensionToPack.has_value()) {
176 newMaps.push_back(map);
181 if (!isa<AffineDimExpr>(map.
getResult(maybeOperandDimensionToPack.value())))
187 newMaps.push_back(map);
190 packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
192 indexingMaps = newMaps;
194 return packedDimPerIndexingMap;
200struct PackedOperandsDim {
201 OpFoldResult packedSize;
202 SmallVector<std::optional<int64_t>> packedDimForEachOperand;
206struct PackedOperandsDimList {
207 void pushBack(PackedOperandsDim &&packedOperandsDims) {
208 spec.emplace_back(packedOperandsDims);
211 SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
213 SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);
216 SmallVector<PackedOperandsDim> spec;
222 linalg::PackOp packOp,
223 bool lowerPadLikeWithInsertSlice) {
225 if (!packOp.hasPureTensorSemantics())
228 auto packedTensorType =
229 cast<RankedTensorType>(packOp->getResultTypes().front());
237 PackingMetadata packingMetadata;
260 for (
auto [pos, innerSize] :
261 llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
263 packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
273 rewriter, loc, map, {outerSize, origSize, innerSize});
275 RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
277 packingMetadata.reassociations);
278 Value paddingValue = packOp.getPaddingValue();
280 paddingValue = arith::ConstantOp::create(
284 tensor::PadOp::create(rewriter, loc, collapsed, packOp.getSource(), lows,
285 highs, paddingValue,
false);
287 LDBG() <<
"insertPositions: "
288 << llvm::interleaved(packingMetadata.insertPositions);
289 LDBG() <<
"outerPositions: "
290 << llvm::interleaved(packingMetadata.outerPositions);
291 LDBG() <<
"packedShape: " << llvm::interleaved(packedTensorType.getShape());
292 LDBG() <<
"packedToStripMinedShapePerm: "
293 << llvm::interleaved(packedToStripMinedShapePerm);
294 LDBG() <<
"reassociations: "
295 << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
297 LDBG() <<
"stripMinedShape: " << llvm::interleaved(stripMinedShape);
298 LDBG() <<
"collapsed type: " << collapsed;
300 if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
319 auto insertSliceOp = tensor::InsertSliceOp::create(
320 rewriter, loc, padOp, packOp.getDest(),
323 LDBG() <<
"insert_slice op: " << insertSliceOp;
325 rewriter.
replaceOp(packOp, insertSliceOp->getResults());
333 auto expandShapeResultType =
335 auto reshapeOp = tensor::ExpandShapeOp::create(
336 rewriter, loc, expandShapeResultType, padOp.getResult(),
337 packingMetadata.reassociations, stripMinedMixedSizes);
342 auto transposeOp = linalg::TransposeOp::create(
343 rewriter, loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
345 LDBG() <<
"reshape op: " << reshapeOp;
346 LDBG() <<
"transpPerm: " << llvm::interleaved(transpPerm);
347 LDBG() <<
"transpose op: " << transposeOp;
350 rewriter.
replaceOp(packOp, transposeOp->getResults());
355FailureOr<LowerUnPackOpResult>
357 bool lowerUnpadLikeWithExtractSlice) {
359 if (!unPackOp.hasPureTensorSemantics())
366 auto packedTensorType = cast<RankedTensorType>(unPackOp.getSourceType());
367 int64_t packedRank = packedTensorType.getRank();
370 auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
371 if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
380 auto extractSliceOp = tensor::ExtractSliceOp::create(
381 rewriter, loc, destTensorType, unPackOp.getSource(),
385 rewriter.
replaceOp(unPackOp, extractSliceOp->getResults());
388 nullptr, extractSliceOp,
394 PackingMetadata packingMetadata;
404 RankedTensorType stripMinedTensorType =
406 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
407 stripMinedTensorType, packingMetadata.reassociations);
414 auto emptyOp = tensor::EmptyOp::create(rewriter, loc, dims,
415 stripMinedTensorType.getElementType());
417 linalg::TransposeOp::create(rewriter, loc, unPackOp.getSource(), emptyOp,
418 packedToStripMinedShapePerm);
420 LDBG() <<
"insertPositions: "
421 << llvm::interleaved(packingMetadata.insertPositions);
422 LDBG() <<
"packedShape: " << llvm::interleaved(packedTensorType.getShape());
423 LDBG() <<
"packedToStripMinedShapePerm: "
424 << llvm::interleaved(packedToStripMinedShapePerm);
425 LDBG() <<
"reassociations: "
426 << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
428 LDBG() <<
"stripMinedShape: " << llvm::interleaved(stripMinedShape);
429 LDBG() <<
"collapsed type: " << collapsedType;
432 auto reshapeOp = tensor::CollapseShapeOp::create(
433 rewriter, loc, collapsedType, transposeOp->getResult(0),
434 packingMetadata.reassociations);
437 int64_t destRank = destTensorType.getRank();
438 auto extractSliceOp = tensor::ExtractSliceOp::create(
439 rewriter, loc, destTensorType, reshapeOp->getResult(0),
445 auto copyOp = linalg::CopyOp::create(
446 rewriter, loc, extractSliceOp->getResult(0), unPackOp.getDest());
449 rewriter.
replaceOp(unPackOp, copyOp->getResults());
456PackedOperandsDimList::extractPackedDimsForOperand(
int64_t operandPos) {
458 for (
auto &i : spec) {
459 if (!i.packedDimForEachOperand[operandPos].has_value())
461 res.push_back(i.packedDimForEachOperand[operandPos].value());
466SmallVector<OpFoldResult>
467PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
468 SmallVector<OpFoldResult> res;
469 for (
auto &i : spec) {
470 if (!i.packedDimForEachOperand[operandPos].has_value())
472 res.push_back(i.packedSize);
481 linalg::LinalgOp linalgOp,
483 if (packedSizes.size() != linalgOp.getNumLoops()) {
485 "incorrect number of pack sizes");
491 linalgOp.getIteratorTypesArray();
492 LDBG() <<
"Start packing: " << linalgOp;
493 LDBG() <<
"maps: " << llvm::interleaved(indexingMaps);
494 LDBG() <<
"iterators: " << llvm::interleaved(iteratorTypes);
499 PackedOperandsDimList listOfPackedOperandsDim;
500 for (
int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
503 if (maybeConstant.has_value() && maybeConstant.value() == 0)
506 PackedOperandsDim packedOperandsDims;
507 packedOperandsDims.packedSize = packedSizes[i];
508 FailureOr<SmallVector<std::optional<int64_t>>>
509 maybePackedDimForEachOperand =
511 if (failed(maybePackedDimForEachOperand))
513 packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
515 LDBG() <<
"++++ After pack size #" << i <<
": " << packedSizes[i];
516 LDBG() <<
"maps: " << llvm::interleaved(indexingMaps);
517 LDBG() <<
"iterators: " << llvm::interleaved(iteratorTypes);
518 LDBG() <<
"packedDimForEachOperand: "
519 << llvm::interleaved(packedOperandsDims.packedDimForEachOperand);
521 listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
527 llvm::to_vector(llvm::make_pointer_range(linalgOp.getDpsInitsMutable()));
529 for (
const auto &operandsList : {inputOperands, initOperands}) {
530 for (
OpOperand *opOperand : operandsList) {
531 int64_t pos = opOperand->getOperandNumber();
532 Value operand = opOperand->get();
534 listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
536 listOfPackedOperandsDim.extractPackSizesForOperand(pos);
537 LDBG() <<
"operand: " << operand;
538 LDBG() <<
"innerPos: " << llvm::interleaved(innerPos);
539 LDBG() <<
"innerPackSizes: " << llvm::interleaved(innerPackSizes);
540 if (innerPackSizes.empty()) {
541 inputsAndInits.push_back(operand);
544 Value dest = linalg::PackOp::createDestinationTensor(
545 rewriter, loc, operand, innerPackSizes, innerPos,
547 ShapedType operandType = cast<ShapedType>(operand.
getType());
548 bool areConstantTiles =
552 if (areConstantTiles && operandType.hasStaticShape() &&
553 !linalg::PackOp::requirePaddingValue(
554 operandType.getShape(), innerPos,
555 cast<ShapedType>(dest.
getType()).getShape(), {},
557 packOps.push_back(linalg::PackOp::create(rewriter, loc, operand, dest,
558 innerPos, innerPackSizes));
564 Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
565 packOps.push_back(linalg::PackOp::create(
566 rewriter, loc, operand, dest, innerPos, innerPackSizes, zero));
568 inputsAndInits.push_back(packOps.back().getResult());
574 ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
576 ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
577 auto packedLinalgOp =
578 linalg::GenericOp::create(rewriter, linalgOp.getLoc(), inits.
getTypes(),
579 inputs, inits, indexingMaps, iteratorTypes);
580 packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));
585 linalg::PackOp maybePackedInit =
586 inits[resultNum].getDefiningOp<linalg::PackOp>();
587 if (!maybePackedInit) {
588 results.push_back(
result);
592 unPackOps.push_back(linalg::UnPackOp::create(
593 rewriter, packedLinalgOp->getLoc(),
result, maybePackedInit.getSource(),
594 maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
595 results.push_back(unPackOps.back().getResult());
603 cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
632 assert(linalgOp == opOperand.
getOwner() &&
"linalg op must own the operand");
636 cast<RankedTensorType>(opOperand.
get().
getType()), permutation);
638 assert(tensorType == transposedValue.
getType() &&
639 "expected tensor type mismatch");
644 llvm::map_to_vector(permutation, [](
int64_t i) ->
unsigned {
return i; });
648 permutationMap.
compose(linalgOp.getMatchingIndexingMap(&opOperand));
652 indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
658 auto transposedGenericOp = linalg::GenericOp::create(
662 operandsRef.drop_front(linalgOp.getNumDpsInputs()).
getTypes(),
663 operandsRef.take_front(linalgOp.getNumDpsInputs()),
664 operandsRef.drop_front(linalgOp.getNumDpsInputs()),
666 linalgOp.getIteratorTypesArray());
667 transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
668 rewriter.
replaceOp(linalgOp, transposedGenericOp->getResults());
670 return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
673FailureOr<PackTransposeResult>
675 linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp,
682 linalg::PackOp transposedPackOp =
683 packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);
685 if (packOp.hasPureBufferSemantics() || !packOp.getResult().hasOneUse())
688 OpOperand &packUse = *packOp->getUses().begin();
689 if (packUse.
getOwner() != linalgOp) {
691 linalgOp,
"not a single use by the LinalgOp target");
694 (!linalgOp.isDpsInit(&packUse) ||
695 maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
697 "not produced by the LinalgOp target");
703 int64_t numLeadingDims = packOp.getSourceRank();
704 int64_t numTrailingDims = packOp.getInnerDimsPos().size();
708 if (permutation.empty())
709 llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
711 if (innerPerm.empty()) {
714 llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
716 llvm::append_range(permutation,
717 llvm::map_range(innerPerm, [&](
int64_t pos) {
718 return numLeadingDims + pos;
730 rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
733 linalg::UnPackOp transposedUnPackOp;
736 transposedLinalgOp->getOpOperand(packUseOperandNumber);
737 OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
739 transposedUnPackOp = maybeUnPackOp.createTransposedClone(
740 rewriter, loc, transposedResult, innerPerm, outerPerm);
742 rewriter.
replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
746 if (packOp.hasPureTensorSemantics())
747 rewriter.
replaceOp(packOp, transposedPackOp->getResults());
772 assert(mnkPackedSizes.size() == 3 &&
"unexpected num of packing sizes");
773 assert((mnkPaddedSizesNextMultipleOf.empty() ||
774 mnkPaddedSizesNextMultipleOf.size() == 3) &&
775 "num of packing sizes next multiple should be empty or of size 3");
776 assert(mnkOrder.size() == 3 &&
"unexpected mnkOrder size");
779 int64_t numLoops = linalgOp.getNumLoops();
781 LDBG() <<
"need 3+ loops to find a matmul to pack, got " << numLoops
782 <<
" in: " << linalgOp;
784 linalgOp,
"need 3+ loops to find a matmul to pack");
788 int64_t numPackedDims = mnkPackedSizes.size();
790 for (
int64_t i = 0, e = numPackedDims; i < e; ++i)
791 mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
793 for (
int64_t i = 0, e = numPackedDims; i < e; ++i)
794 packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
796 for (
int64_t i = 0, e = numPackedDims; i < e; ++i) {
797 paddedSizesNextMultipleOf[mnkOrder[i]] =
798 mnkPaddedSizesNextMultipleOf.empty() ? 0
799 : mnkPaddedSizesNextMultipleOf[i];
803 FailureOr<ContractionDimensions> maybeDimensions =
805 if (failed(maybeDimensions)) {
806 LDBG() <<
"couldn't infer matmul iterators in: " << linalgOp;
808 "couldn't infer matmul iterators");
816 int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
817 kPos = maybeDimensions->k.back();
818 LDBG() <<
"Start packing generic op greedily with (m@" << mPos <<
", n@"
819 << nPos <<
", k@" << kPos <<
"): " << linalgOp;
822 auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
824 FailureOr<GenericOp> generalizeResult =
826 assert(succeeded(generalizeResult) &&
"unexpected failure generalizing op");
827 genericOp = *generalizeResult;
835 LDBG() <<
"perm: " << llvm::interleaved(permutation);
838 FailureOr<GenericOp> interchangeResult =
840 assert(succeeded(interchangeResult) &&
"unexpected failure interchanging op");
841 genericOp = *interchangeResult;
842 LDBG() <<
"Generalized Op to pack: " << genericOp;
859 cast<LinalgOp>(genericOp.getOperation())
860 .createLoopRanges(rewriter, genericOp.getLoc());
864 LDBG() <<
"paddedSizesNextMultipleOf: "
865 << llvm::interleaved(paddedSizesNextMultipleOf);
866 LDBG() <<
"loopRanges: "
867 << llvm::interleaved(
868 llvm::map_range(loopRanges, [](
Range r) {
return r.
size; }));
871 for (
int64_t i = 0, e = numPackedDims; i < e; ++i) {
872 if (paddedSizesNextMultipleOf[i] == 0) {
873 adjustedPackedSizes.push_back(packedSizes[i]);
880 rewriter, genericOp->getLoc(), d0.
ceilDiv(s0) * s0,
881 {loopRanges[adjustedPackedSizes.size()].size,
882 rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
884 LDBG() <<
"adjustedPackedSizes: " << llvm::interleaved(adjustedPackedSizes);
890 return pack(rewriter, genericOp, adjustedPackedSizes);
903 b.setInsertionPointToStart(
904 &op->getParentOfType<func::FuncOp>().getBody().front());
905 return llvm::map_to_vector<4>(tileSizes, [&](
int64_t s) {
923 auto padValue = padOp.getConstantPaddingValue();
926 if (padValue.getParentBlock() == &padOp.getRegion().front())
928 return FillOp::create(rewriter, padOp.getLoc(), padValue, dest).result();
932 auto generateOp = tensor::GenerateOp::create(rewriter, padOp.getLoc(),
933 padOp.getResultType(), dynSizes);
936 padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
945 if (
auto val = llvm::dyn_cast_if_present<Value>(ofr))
948 rewriter, padOp.getLoc(),
949 cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
953 auto resultType = padOp.getResultType();
957 for (
unsigned dim = 0; dim < resultType.getRank(); ++dim) {
958 if (resultType.isDynamicDim(dim)) {
960 padOp.getSource(), dim));
963 padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
965 padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
966 dynSizes.push_back(plusHigh);
968 staticSizes.push_back(resultType.getDimSize(dim));
973 tensor::EmptyOp::create(rewriter, padOp.getLoc(), staticSizes,
974 resultType.getElementType(), dynSizes);
978 auto sourceType = padOp.getSourceType();
986 padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
994 if (!sliceOp.hasUnitStride())
997 auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
1001 bool zeroSliceGuard =
true;
1003 if (std::optional<bool> control = controlFn(sliceOp))
1004 zeroSliceGuard = *control;
1009 FailureOr<TilingResult> tilingResult =
1011 sliceOp.getMixedSizes(), zeroSliceGuard);
1012 if (failed(tilingResult))
1015 RankedTensorType sourceType = sliceOp.getSourceType();
1016 RankedTensorType resultType = sliceOp.getResultType();
1020 if (sourceType.getRank() == resultType.getRank()) {
1021 rewriter.
replaceOp(sliceOp, tilingResult->tiledValues);
1027 rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);
1029 rewriter.
replaceOp(sliceOp, rankReduced);
1039 linalg::PackOp packOp) {
1040 Value input = packOp.getSource();
1042 if (!packOp.hasPureTensorSemantics())
1045 if (!packOp.getPaddingValue()) {
1049 assert(llvm::all_of(packOp.getAllOuterDims(),
1050 [](
int64_t val) { return val == 1; }) &&
1051 "some outer dims are != 1");
1054 ShapedType inputType = packOp.getSourceType();
1055 int64_t inputRank = inputType.getRank();
1058 packOp.getDimAndTileMapping();
1065 for (
int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
1068 if (!tileAndPosMapping.count(dimIdx)) {
1069 int64_t inputDimSize = inputType.getDimSize(dimIdx);
1070 assert(inputDimSize == 1 &&
1071 "with all outer dims == 1, this non-tiled input dim should be 1!");
1072 paddedShape.push_back(inputDimSize);
1079 OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
1083 if (cstTileSize.has_value()) {
1084 paddedShape.push_back(cstTileSize.value());
1089 paddedShape.push_back(ShapedType::kDynamic);
1092 dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
1095 RankedTensorType::get(paddedShape, inputType.getElementType());
1097 false, loc, builder,
1105static SmallVector<int64_t>
1107 constexpr int64_t kNonTiledMarker = -1;
1109 for (
auto [
index, value] : llvm::enumerate(perm))
1112 vec, [&](
int64_t v) {
return v != kNonTiledMarker; });
1119static SmallVector<int64_t>
1128 for (
auto i : llvm::seq<unsigned>(0, unpackedRank)) {
1129 if (llvm::is_contained(innerDimsPos, i)) {
1130 innerDims.push_back(dim++);
1135 outerDims.push_back(dim++);
1136 if (!outerDimsPerm.empty())
1137 rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
1148 rankReducedOuterDimsPerm =
1150 if (!rankReducedOuterDimsPerm.empty())
1154 perm.append(innerDims);
1162 if (!packOp.hasPureTensorSemantics())
1165 if (llvm::any_of(packOp.getTiledOuterDims(),
1166 [](
int64_t dim) { return dim != 1; })) {
1168 packOp,
"not all outer dimensions of the result are 1s");
1172 auto outerDimsPerm = packOp.getOuterDimsPerm();
1178 if (!llvm::all_of(outerDimsPerm, [&innerDimsPos, &packOp](
int64_t dim) {
1179 static int prev = 0;
1181 if (llvm::is_contained(innerDimsPos, dim))
1186 if (dim < prev && (packOp.getResult().getType().getShape()[prev] != 1 ||
1187 packOp.getResult().getType().getShape()[dim] != 1))
1194 packOp,
"At least one non-unit and un-tiled outer dim is permuted, "
1195 "this is not supported ATM!");
1200 int64_t srcRank = packOp.getSourceRank();
1219 for (
int64_t i = 0; i < srcRank; i++) {
1227 if (llvm::is_contained(innerDimsPos, i))
1229 srcPermForTranspose.push_back(i);
1231 srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());
1235 ShapedType inputTy = cast<ShapedType>(input.
getType());
1237 for (
int64_t i = 0; i < srcRank; i++) {
1238 if (llvm::is_contained(innerDimsPos, i)) {
1242 if (inputTy.isStaticDim(i))
1243 shapeForEmptyOp.push_back(rewriter.
getIndexAttr(inputTy.getShape()[i]));
1245 shapeForEmptyOp.emplace_back(
1246 tensor::DimOp::create(rewriter, loc, input, i).getResult());
1248 shapeForEmptyOp.append(packOp.getMixedTiles());
1255 llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(),
1257 if (auto val = llvm::dyn_cast<Value>(ofr))
1258 return getAsOpFoldResult(val);
1262 LDBG() <<
"Pack permutation: " << packOp;
1263 LDBG() <<
"perm: " << llvm::interleaved(srcPermForTranspose);
1264 LDBG() <<
"Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp);
1266 Value empty = tensor::EmptyOp::create(
1267 rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType());
1270 auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
1271 srcPermForTranspose);
1283 for (
auto size : packOp.getAllOuterDims()) {
1287 for (
auto tileSize : packOp.getMixedTiles()) {
1288 auto [_, tileSizeOfr] =
1290 writeSizes.push_back(tileSizeOfr);
1293 auto insert = tensor::InsertSliceOp::create(
1294 rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeSizes);
1297 rewriter.
replaceOp(packOp, insert.getResult());
1304 if (!unpackOp.hasPureTensorSemantics())
1307 int64_t destRank = unpackOp.getDestRank();
1310 if (llvm::any_of(unpackOp.getTiledOuterDims(),
1311 [](
int64_t dim) { return dim != 1; })) {
1314 "require the tiled outer dimensions of the result are all 1s");
1320 Value source = unpackOp.getSource();
1322 unpackOp.getDimAndTileMapping();
1341 for (
auto i : llvm::seq<unsigned>(0, destRank)) {
1350 if (dimAndTileMapping.count(i)) {
1351 extractSliceSizes.push_back(oneIdxAttr);
1357 if (ShapedType::isDynamic(srcShape[i])) {
1359 tensor::DimOp::create(rewriter, loc, source, i).getResult();
1360 extractSliceSizes.push_back(dynamicDim);
1361 shapeForEmptyOp.push_back(dynamicDim);
1363 extractSliceSizes.push_back(rewriter.
getIndexAttr(srcShape[i]));
1364 if (srcShape[i] != 1)
1365 shapeForEmptyOp.push_back(rewriter.
getIndexAttr(srcShape[i]));
1369 if (srcShape[i] != 1) {
1370 readShapeForExtractSlice.push_back(srcShape[i]);
1375 auto mixedTiles = unpackOp.getMixedTiles();
1376 extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
1377 shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());
1381 auto tileShape = srcShape.drop_front(destRank);
1383 readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
1384 Type elemType = unpackOp.getSourceType().getElementType();
1385 auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
1386 Value innerTile = tensor::ExtractSliceOp::create(
1387 rewriter, loc, readType, unpackOp.getSource(), extractSliceSizes);
1391 srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
1397 tensor::EmptyOp::create(rewriter, loc, shapeForEmptyOp, elemType);
1399 linalg::TransposeOp::create(rewriter, loc, innerTile, empty, perm);
1405 for (
auto i : llvm::seq<unsigned>(0, destRank)) {
1406 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1407 tileSizes.push_back(
1412 tensor::ExtractSliceOp::create(rewriter, loc, RankedTensorType(),
1413 transposedOp.getResult()[0], tileSizes);
1417 for (
int i = 0, idx = 0; i < destRank; ++i) {
1418 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1419 writeSizes.push_back(tileSizes[idx++]);
1421 writeSizes.push_back(oneIdxAttr);
1423 auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile,
1424 unpackOp.getDest(), writeSizes);
1425 rewriter.
replaceOp(unpackOp, insert.getResult());
1439 for (
unsigned dim : dims) {
1443 resultIndices.push_back(i);
1448 return resultIndices;
1456 auto tensorType = cast<RankedTensorType>(
tensor.getType());
1457 int64_t rank = tensorType.getRank();
1461 for (
int64_t i = 0; i < rank; ++i) {
1462 if (!llvm::is_contained(dimsToRemove, i))
1463 newShape.push_back(tensorType.getDimSize(i));
1466 auto newType = RankedTensorType::get(newShape, tensorType.getElementType());
1474static std::optional<AffineExpr>
1478 bool onlyReferencesDroppedDims =
true;
1479 for (
unsigned d = 0; d < newNumDims + dimsToDrop.size(); ++d) {
1481 onlyReferencesDroppedDims =
false;
1485 if (onlyReferencesDroppedDims && llvm::any_of(dimsToDrop, [&](
unsigned d) {
1488 return std::nullopt;
1493 unsigned newDimIdx = 0;
1494 for (
unsigned d = 0; d < newNumDims + dimsToDrop.size(); ++d) {
1495 if (llvm::is_contained(dimsToDrop, d)) {
1509 if (failed(maybeDims))
1513 if (maybeDims->outputImage.size() != 2 || maybeDims->filterLoop.size() != 2)
1516 if (op.hasPureBufferSemantics())
1520 unsigned outSpatial0 = maybeDims->outputImage[0];
1521 unsigned outSpatial1 = maybeDims->outputImage[1];
1522 unsigned filterSpatial0 = maybeDims->filterLoop[0];
1523 unsigned filterSpatial1 = maybeDims->filterLoop[1];
1527 int64_t outSize0 = loopRanges[outSpatial0];
1528 int64_t outSize1 = loopRanges[outSpatial1];
1529 int64_t filterSize0 = loopRanges[filterSpatial0];
1530 int64_t filterSize1 = loopRanges[filterSpatial1];
1533 bool canRemoveSpatial0 = (filterSize0 == 1 && outSize0 == 1);
1534 bool canRemoveSpatial1 = (filterSize1 == 1 && outSize1 == 1);
1535 if (!canRemoveSpatial0 && !canRemoveSpatial1)
1542 if (canRemoveSpatial0) {
1543 loopDimsToRemove.push_back(outSpatial0);
1544 loopDimsToRemove.push_back(filterSpatial0);
1546 loopDimsToRemove.push_back(outSpatial1);
1547 loopDimsToRemove.push_back(filterSpatial1);
1549 llvm::sort(loopDimsToRemove);
1554 unsigned numDims = op.getNumLoops();
1555 unsigned newNumDims = numDims - loopDimsToRemove.size();
1556 for (
AffineMap map : op.getIndexingMapsArray()) {
1562 newResults.push_back(*newExpr);
1564 newMaps.push_back(
AffineMap::get(newNumDims, 0, newResults, ctx));
1569 auto iterTypes = op.getIteratorTypesArray();
1570 for (
unsigned idx = 0; idx < iterTypes.size(); ++idx) {
1571 if (!llvm::is_contained(loopDimsToRemove, idx))
1572 newIterTypes.push_back(iterTypes[idx]);
1578 for (
OpOperand *input : op.getDpsInputOperands()) {
1579 AffineMap map = op.getMatchingIndexingMap(input);
1583 tensorDimsToRemove);
1584 newInputs.push_back(reduced);
1587 OpOperand &output = *op.getDpsInitsMutable().begin();
1588 AffineMap outputMap = op.getMatchingIndexingMap(&output);
1592 outputDimsToRemove);
1597 newInputs, newOutput, newMaps, newIterTypes);
1599 newOp.getRegion().begin());
1603 LinalgOp resultOp = newOp;
1604 if (!isa<GenericOp>(op)) {
1606 if (succeeded(specializedOp))
1607 resultOp = *specializedOp;
1612 rewriter, loc, resultOp->getResult(0), output.
get());
1620struct DownscaleSizeOneWindowedConvolution final
1622 DownscaleSizeOneWindowedConvolution(
MLIRContext *context,
1626 LogicalResult matchAndRewrite(LinalgOp op,
1635 patterns.
add<DownscaleSizeOneWindowedConvolution>(patterns.
getContext(),
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr replaceDims(ArrayRef< AffineExpr > dimReplacements) const
Dim-only version of replaceDimsAndSymbols.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
MLIRContext * getContext() const
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ... numDims) by dims[offset + shift ... shift + numDims).
AffineMap insertResult(AffineExpr expr, unsigned pos) const
Returns a new AffineMap with the same number of dims and symbols and an extra result inserted at pos.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
TypedAttr getZeroAttr(Type type)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp, PackingMetadata &metadata)
Compute inverse permutation for the source tensor.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite unpack as empty + transpose + reshape + extract_slice + copy.
void peelLoops(RewriterBase &rewriter, ArrayRef< scf::ForOp > loops)
Peel 'loops' and apply affine_min/max bounds simplification on the fly where relevant.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp, const GenericOpSpecializationOptions &options={})
Replace the given GenericOp with a namedOp or categoryOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Linalg decompose convolutions patterns.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp, PackingMetadata &metadata)
Compute inverse permutation for the destination tensor (i.e.
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
SmallVector< Value > peelLoop(RewriterBase &rewriter, Operation *op)
Try to peel and canonicalize loop op and return the new result.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
FailureOr< LinalgOp > downscaleSizeOneWindowedConvolution(RewriterBase &rewriter, LinalgOp op)
Rewrite convolution/pooling/depthwise ops with size-1 window dimensions into lower-dimensional ops.
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
FailureOr< TilingResult > bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, bool generateZeroSliceGuard=true)
Bubbles up a slice of this pad by taking the slice first and then performing the padding.
PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad, bool nofold, Location loc, OpBuilder &builder, ValueRange dynOutDims={})
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
ArrayRef< int64_t > ReassociationIndicesRef
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< int64_t > computePermutationVector(int64_t permSize, ArrayRef< int64_t > positions, ArrayRef< int64_t > desiredPositions)
Return a permutation vector of size permSize that would result in moving positions into desiredPositi...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
std::pair< int64_t, OpFoldResult > getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b)
Given OpFoldResult representing dim size value (*), generates a pair of sizes:
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter &rewriter) const override
Rewrites a linalg::PackOp into a sequence of:
LogicalResult matchAndRewrite(linalg::PackOp packOp, PatternRewriter &rewriter) const override
Rewrites a linalg::UnPackOp into a sequence of:
LogicalResult matchAndRewrite(linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const override
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const override
Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp, Value dest, const SmallVector< Value > &dynSizes) const
Filling dest using FillOp constant padding value if possible.
LinalgTilingOptions & setTileSizes(const SmallVector< Value, 4 > &ts)
Set the tileSizeComputationFunction to return the values ts.
TileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes for each operation.
Struct to hold the result of a pack call.
Struct to hold the result of a packTranspose call.